import sys

from tqdm import tqdm
import torch

from train_multi_GPU.multi_train_utils.distributed_utils import reduce_value, is_main_process


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    '''
    功能: 一个epoch的模型训练学习
    输入：
        model: 待训练的模型
        optimizer: 优化器
        data_loader: DataLoader数据
        device: 设备
        epoch: 第几轮
        wandb: 监控【在每个batch输出】
    输出：
        平均损失
    '''
    # 模型训练模式
    model.train()
    # 损失函数
    loss_function = torch.nn.CrossEntropyLoss()
    # 平均损失。初始值为0
    mean_loss = torch.zeros(1).to(device)
    # 清空优化器的梯度信息
    optimizer.zero_grad()

    # 在进程0中打印训练进度（仅在主进程中打印进度条）
    if is_main_process():
        # 进度条封装data_loader
        data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        # step: 第几个batch
        images, labels = data

        # 预测值
        pred = model(images.to(device))
        # 计算损失值：预测值，真实值
        loss = loss_function(pred, labels.to(device))
        # 梯度反向传播
        loss.backward()
        # 对多GPU的loss进行求和后的均值
        loss = reduce_value(loss, average=True)

        # update mean losses 计算所有batch的平均损失
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  

        # 在进程0中打印平均loss
        if is_main_process():
            # data_loader封装到了tqdm里面，就可以修改进度条的描述内容
            data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        # 每次batch更新参数，并清空梯度
        optimizer.step()
        optimizer.zero_grad()

    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    return mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device):
    '''
    功能: 评估模型效果
    输入:
        model: 训练好的模型
        data_loader: DataLoader数据
        device: 设备
    输出:
        预测正确的样本个数
    '''
    # 进入验证模式
    model.eval()

    # 用于存储预测正确的样本个数
    sum_num = torch.zeros(1).to(device)

    # 在进程0中打印验证进度
    if is_main_process():
        data_loader = tqdm(data_loader, file=sys.stdout)

    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    # 等待所有进程计算完毕
    if device != torch.device("cpu"):
        torch.cuda.synchronize(device)

    # 对多GPU求和
    sum_num = reduce_value(sum_num, average=False)

    return sum_num.item()






